import torch

# Build a 3-D tensor of shape (3, 2, 1): 3 entries, each holding 2 single-element lists.
a = torch.tensor([
    [[2], [3]],
    [[4], [5]],
    [[6], [7]],
])
print(a.shape)  # torch.Size([3, 2, 1])

# Remove the size-1 trailing dimension (dim 2): shape becomes (3, 2).
b = torch.squeeze(a, 2)
print(b.shape)  # reuse b instead of recomputing the squeeze
print(b)

# Re-insert the removed dimension to recover the original (3, 2, 1) tensor.
# For a 2-D tensor, unsqueeze accepts dim in [-3, 2]; the previous dim=3 was
# out of range and raised IndexError. dim=2 appends the new axis at the end,
# exactly undoing the squeeze above.
c = torch.unsqueeze(b, 2)
print(c)